#  Mycroft Server - Backend
#  Copyright (c) 2022 Mycroft AI Inc
#  SPDX-License-Identifier: 	AGPL-3.0-or-later
#  #
#  This file is part of the Mycroft Server.
#  #
#  The Mycroft Server is free software: you can redistribute it and/or
#  modify it under the terms of the GNU Affero General Public License as
#  published by the Free Software Foundation, either version 3 of the
#  License, or (at your option) any later version.
#  #
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#  GNU Affero General Public License for more details.
#  #
#  You should have received a copy of the GNU Affero General Public License
#  along with this program. If not, see <https://www.gnu.org/licenses/>.
#
"""Defines data access methods for the metric.stt_transcription table."""
from dataclasses import asdict
from datetime import date
from decimal import Decimal
from typing import List

from ...repository_base import RepositoryBase
from ..entity.stt import SttTranscriptionMetric


class TranscriptionMetricRepository(RepositoryBase):
    """Defines data access methods for the metric.stt_transcription table."""

    def __init__(self, db):
        super().__init__(db, __file__)

    def add(self, metric: SttTranscriptionMetric) -> str:
        """Adds a row to the metric.stt_transcription table.

        The audio and transcription durations are rounded to three decimal
        places before being bound to the insert statement.

        :param metric: the metric to insert into the database
        :returns: the ID generated by the insert statement
        """
        sql_args = asdict(metric)
        # Round the durations to the precision stored by the table.
        # NOTE(review): assumes both durations are Decimal instances, since
        # float has no quantize() method.
        sql_args.update(
            audio_duration=metric.audio_duration.quantize(Decimal("0.001")),
            transcription_duration=metric.transcription_duration.quantize(
                Decimal("0.001")
            ),
        )
        db_request = self._build_db_request(
            sql_file_name="add_tts_transcription_metric.sql", args=sql_args
        )
        db_result = self.cursor.insert_returning(db_request)

        return db_result["id"]

    def get_by_account(self, account_id: str) -> List[SttTranscriptionMetric]:
        """Get all the STT transcription metrics for an account.

        :param account_id: identifier of the account to use in the query
        :returns: query results, one dataclass instance per row
        """
        return self._select_all_into_dataclass(
            SttTranscriptionMetric,
            sql_file_name="get_tts_transcription_by_account.sql",
            args=dict(account_id=account_id),
        )

    def delete_by_date(self, transcription_date: date):
        """Delete all STT transcription metrics for a day.

        The data on the metric.stt_transcription table is attributable to an
        account. After aggregating the data for a day into the
        metric.stt_engine table, delete the attributable data to comply with
        our privacy policy.

        :param transcription_date: the date the transcription was requested
        """
        # This must be a DELETE query keyed on the date; the insert query used
        # by add() expects a full metric row and cannot remove anything.
        # The file name follows the "<verb>_tts_transcription_..." convention
        # of the sibling queries in this class.
        # NOTE(review): confirm this SQL file exists in the repository's sql
        # directory alongside the add/get queries.
        db_request = self._build_db_request(
            sql_file_name="delete_tts_transcription_by_date.sql",
            args=dict(transcription_date=transcription_date),
        )
        self.cursor.delete(db_request)
